import pandas as pd
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer, util
from pathlib import Path

# Register tqdm's progress-bar helpers (e.g. ``progress_apply``) on pandas
# objects; the scoring functions in this file rely on ``progress_apply``.
tqdm.pandas()

# Absolute directory containing this file. The loader functions below build
# their input paths relative to it rather than to the working directory.
script_dir = Path(__file__).resolve().parent


def load_sentences():
    """Read the manifesto corpus feather file and return it as a DataFrame.

    The file is located relative to this script's directory, not the
    current working directory.
    """
    corpus_path = script_dir / "../data/source/manifesto_corpus.feather"
    return pd.read_feather(corpus_path)


def load_questions():
    """Read the moral-foundations questions CSV and return it as a DataFrame.

    The file is located relative to this script's directory and is parsed
    with ``"`` as the quote character.
    """
    questions_path = script_dir / "../data/ccr/moral_foundations_questions.csv"
    return pd.read_csv(questions_path, quotechar='"')


def load_vignettes():
    """Read the vignettes CSV and return it as a DataFrame.

    The file is located relative to this script's directory. Fields are
    separated by ``;`` and quoted with ``"``.
    """
    vignettes_path = script_dir / "../data/ccr/vignettes.csv"
    csv_options = {"quotechar": '"', "sep": ";"}
    return pd.read_csv(vignettes_path, **csv_options)


def build_vectors(
    model, only_english=False, score_questions=True, score_vignettes=True
):
    """Encode the moral-foundation seed texts into per-language embeddings.

    Args:
        model: Object exposing ``encode(list_of_texts)``; callers in this
            file pass a ``SentenceTransformer``.
        only_english: When true, only the language code ``"en"`` is
            processed. Otherwise every language listed in the respective
            CSV's ``language`` column is processed.
        score_questions: When true, encode the questionnaire items and store
            them under ``"<foundation>_virtue"``.
        score_vignettes: When true, encode the vignettes and store them
            under ``"<foundation>_vice"``.

    Returns:
        Nested dict ``{language: {key: embeddings}}`` where ``key`` is
        ``"<foundation>_virtue"`` or ``"<foundation>_vice"`` and
        ``embeddings`` is whatever ``model.encode`` returned.
    """
    ccr_vectors = {}

    if score_questions:
        questions = load_questions()
        # dict.fromkeys de-duplicates while keeping file order, so the key
        # order of the result is stable between runs (a set's is not).
        if only_english:
            question_languages = ["en"]
        else:
            question_languages = list(dict.fromkeys(questions["language"]))
        question_foundations = list(dict.fromkeys(questions["foundation"]))

        for language in question_languages:
            language_vectors = ccr_vectors.setdefault(language, {})
            for moral_foundation in question_foundations:
                matching_rows = questions[
                    (questions["foundation"] == moral_foundation)
                    & (questions["language"] == language)
                ]
                # A foundation/language pair without a questionnaire row is
                # skipped instead of crashing on ``.iloc[0]``.
                if matching_rows.empty:
                    continue
                # Column positions 2..4 of the first matching row are encoded
                # (presumably the item texts -- confirm against the CSV).
                # A plain list is passed so ``encode`` never has to index a
                # label-indexed Series by integer position.
                seed_texts = matching_rows.iloc[0, 2:5].tolist()
                language_vectors[f"{moral_foundation}_virtue"] = model.encode(
                    seed_texts
                )

    if score_vignettes:
        vignettes = load_vignettes()
        if only_english:
            vignette_languages = ["en"]
        else:
            vignette_languages = list(dict.fromkeys(vignettes["language"]))
        vignette_foundations = list(dict.fromkeys(vignettes["foundation"]))

        for language in vignette_languages:
            # setdefault also covers languages that only appear in the
            # vignettes file; a direct lookup would raise KeyError there.
            language_vectors = ccr_vectors.setdefault(language, {})
            for moral_foundation in vignette_foundations:
                seed_texts = vignettes.loc[
                    (vignettes["foundation"] == moral_foundation)
                    & (vignettes["language"] == language),
                    "vignette",
                ].tolist()
                try:
                    encoded = model.encode(seed_texts)
                except IndexError:
                    # Best-effort behaviour kept from the original code: an
                    # IndexError while encoding means "no vectors for this
                    # pair". NOTE(review): confirm this is still needed.
                    continue
                language_vectors[f"{moral_foundation}_vice"] = encoded

    return ccr_vectors


def apply_ccr(
    input,
    model,
    ccr_vectors: dict,
    strategy: str,
):
    """Score one row against the seed embeddings of each moral foundation.

    Strategies:
        en_to_en        English sentence vs. English seed texts.
        multi_to_en     Multilingual sentence vs. English seed texts
                        (aligned embeddings).
        multi_to_multi  Multilingual sentence vs. seed texts in the row's
                        own language.

    Args:
        input: Row-like mapping. ``"text_en"`` is read for ``en_to_en`` and
            ``"text"`` for every other strategy; ``"language_iso"`` is read
            for ``multi_to_multi`` only.
        model: Object exposing ``encode``.
        ccr_vectors: ``{language: {name: embeddings}}`` mapping.
        strategy: Strategy name; it is also embedded in the output keys.

    Returns:
        Dict mapping ``"ccr_<strategy>_<name>"`` to the mean cosine
        similarity between the row's embedding and that entry's seed
        embeddings. Entries with zero-length seed embeddings are omitted.

    Raises:
        KeyError: If the selected language has no entry in ``ccr_vectors``.
    """
    # Only multi_to_multi matches the row's language; otherwise use English.
    language_key = input["language_iso"] if strategy == "multi_to_multi" else "en"
    seed_vectors = ccr_vectors[language_key]

    # en_to_en scores the English text; other strategies the original text.
    text_column = "text_en" if strategy == "en_to_en" else "text"
    input_encoding = model.encode(input[text_column])

    scores = {}
    for name, vectors in seed_vectors.items():
        if len(vectors) == 0:
            # Nothing to compare against; leave this key out of the result.
            continue
        # TODO: decide between mean and max aggregation; mean is used now.
        similarities = util.cos_sim(input_encoding, vectors)
        scores[f"ccr_{strategy}_{name}"] = torch.mean(similarities).item()
    return scores


def get_transformer_model(multilingual: bool):
    """Return the sentence-transformer model used for scoring.

    Args:
        multilingual: When true, load
            ``paraphrase-multilingual-MiniLM-L12-v2``; otherwise load
            ``all-MiniLM-L6-v2``.
    """
    model_name = (
        "paraphrase-multilingual-MiniLM-L12-v2"
        if multilingual
        else "all-MiniLM-L6-v2"
    )
    return SentenceTransformer(model_name)


def _score_and_export(strategy, multilingual, only_english, output_name):
    """Score every corpus sentence with one strategy and write the result.

    Args:
        strategy: Strategy name forwarded to ``apply_ccr``.
        multilingual: Whether to load the multilingual transformer model.
        only_english: Forwarded to ``build_vectors``.
        output_name: File name of the feather file to write.
    """
    sentences = load_sentences()
    model = get_transformer_model(multilingual=multilingual)
    ccr_vectors = build_vectors(model, only_english=only_english)
    scores = sentences.progress_apply(
        apply_ccr,
        model=model,
        ccr_vectors=ccr_vectors,
        strategy=strategy,
        axis="columns",
        result_type="expand",
    )
    export = pd.concat([scores, sentences["id_for_project"]], axis="columns")
    # Write next to the input data, relative to this script like the loaders,
    # so the result does not depend on the current working directory.
    export.reset_index().to_feather(script_dir / "../data/ccr" / output_name)


def score_en_to_en():
    """Score the English text with the English model and English seeds.

    Writes ``ccr_en_to_en.feather``.
    """
    _score_and_export(
        strategy="en_to_en",
        multilingual=False,
        only_english=True,
        output_name="ccr_en_to_en.feather",
    )


def score_multi_to_en():
    """Score the original-language text against English seeds.

    Uses the multilingual model and writes ``ccr_multi_to_en.feather``.
    """
    # apply_ccr only reads the "en" seed vectors for this strategy, so the
    # other languages do not need to be encoded.
    _score_and_export(
        strategy="multi_to_en",
        multilingual=True,
        only_english=True,
        output_name="ccr_multi_to_en.feather",
    )


def score_multi_to_multi():
    """Score the original-language text against seeds in the same language.

    Uses the multilingual model and writes ``ccr_multi_to_multi.feather``.
    """
    _score_and_export(
        strategy="multi_to_multi",
        multilingual=True,
        only_english=False,
        output_name="ccr_multi_to_multi.feather",
    )


if __name__ == "__main__":
    # Run the three scoring strategies in sequence, announcing each one
    # before it starts.
    scoring_steps = (
        ("Scoring english sentences with english embeddings", score_en_to_en),
        (
            "Scoring multilingual sentences with english embeddings",
            score_multi_to_en,
        ),
        (
            "Scoring multilingual sentences with multilingual embeddings",
            score_multi_to_multi,
        ),
    )
    for announcement, run_step in scoring_steps:
        print(announcement)
        run_step()
